{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "# Expose only the selected GPU to this process; this runs before torch/cudf are imported in the next cell.\n",
    "GPU_id = 0\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = str(GPU_id)\n",
    "import warnings\n",
    "# Suppress all Python warnings for the rest of the session.\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'0.9.0'"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from time import time \n",
    "\n",
    "from fastai import *\n",
    "from fastai.basic_data import *\n",
    "from fastai.basic_data import *\n",
    "from fastai.tabular import *\n",
    "from fastai.basic_data import DataBunch\n",
    "from fastai.tabular import TabularModel\n",
    "\n",
    "import cudf\n",
    "\n",
    "from preproc import *\n",
    "from batchloader import *\n",
    "from helpers import get_mean_reciprocal_rank, roc_auc_score\n",
    "cudf.__version__"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext snakeviz\n",
    "# Registers the snakeviz magics. The training cell below starts with `%%snakeviz`, so this extension must be loaded unless that magic is removed."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h1> <center> <a id=batchdatabunch>New Data Bunch </a></center> </h1> "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define a custom fastai `DataBunch` that wraps each dataset in a `TensorBatchDataset` and serves it through a `BatchDataLoader` instead of the usual torch `DataLoader`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class BatchDataBunch(DataBunch):\n",
    "    \"\"\"`DataBunch` variant backed by `BatchDataLoader`s over `TensorBatchDataset`s.\n",
    "\n",
    "    NOTE: every method here is a classmethod, and `create` stores its results as\n",
    "    *class* attributes and returns the class itself, not an instance. A later call\n",
    "    to `create` therefore overwrites the loaders seen through earlier references.\n",
    "    \"\"\"\n",
    "\n",
    "    @classmethod\n",
    "    def remove_tfm(cls, tfm:Callable)->None:\n",
    "        \"Remove `tfm` from `cls.tfms` if it is present.\"\n",
    "        if tfm in cls.tfms: cls.tfms.remove(tfm)\n",
    "\n",
    "    @classmethod\n",
    "    def add_tfm(cls,tfm:Callable)->None:\n",
    "        \"Append `tfm` to `cls.tfms`.\"\n",
    "        cls.tfms.append(tfm)\n",
    "\n",
    "    @classmethod\n",
    "    def create(cls, train_ds, valid_ds, test_ds=None, path:PathOrStr='.', bs:int=64, val_bs=None, \n",
    "                      num_workers:int=defaults.cpus, device:torch.device=None,\n",
    "                      collate_fn:Callable=data_collate, tfms: List[Callable]=None, \n",
    "                       size:int=None, **kwargs)->'BatchDataBunch':\n",
    "        \"\"\"Build train/valid(/test) `BatchDataLoader`s and attach them to the class.\n",
    "\n",
    "        Each dataset argument is wrapped in a `TensorBatchDataset`; the training\n",
    "        loader shuffles, the others do not. `bs` is the training batch size and\n",
    "        `val_bs` (defaults to `bs`) is used for the validation and test sets.\n",
    "        `valid_ds` and `test_ds` may be None, in which case the matching loader\n",
    "        attribute is set to None; `empty_val` is True when `valid_ds` is None.\n",
    "\n",
    "        `num_workers`, `collate_fn`, `size` and `**kwargs` are accepted but not used.\n",
    "\n",
    "        Returns the class itself (not an instance), with `train_dl`, `valid_dl`,\n",
    "        `test_dl`, `dls`, `device`, `path`, `tfms` and `empty_val` set on it.\n",
    "        \"\"\"\n",
    "        cls.tfms = listify(tfms)\n",
    "        val_bs = ifnone(val_bs, bs)\n",
    "        cls.device = defaults.device if device is None else device\n",
    "        cls.path = Path(path)\n",
    "        cls.empty_val = valid_ds is None\n",
    "\n",
    "        def _loader(ds, batch_size, shuffle):\n",
    "            \"Wrap `ds` in a `TensorBatchDataset` served by a `BatchDataLoader` on `cls.device`.\"\n",
    "            return BatchDataLoader(TensorBatchDataset(ds, batch_size=batch_size), shuffle=shuffle,\n",
    "                                   pin_memory=False, drop_last=False, device=cls.device)\n",
    "\n",
    "        cls.train_dl = _loader(train_ds, bs, True)\n",
    "        # Validation and test use `val_bs`; previously it was computed but ignored.\n",
    "        # Optional splits are only wrapped when provided, and the attribute is always\n",
    "        # (re)assigned so a loader from an earlier `create` call cannot linger.\n",
    "        cls.valid_dl = None if valid_ds is None else _loader(valid_ds, val_bs, False)\n",
    "        cls.test_dl = None if test_ds is None else _loader(test_ds, val_bs, False)\n",
    "\n",
    "        # Only the loaders that actually exist, in train/valid/test order.\n",
    "        cls.dls = [dl for dl in (cls.train_dl, cls.valid_dl, cls.test_dl) if dl is not None]\n",
    "\n",
    "        assert not isinstance(cls.train_dl, DeviceDataLoader)\n",
    "        return cls\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Config of the fastest workflow "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Forwarded to PreprocessDF(to_cpu=...) in the preprocessing cell.\n",
    "to_cpu = False\n",
    "# Passed as `bs` to BatchDataBunch.create (4096*50 = 204800 rows per batch).\n",
    "batch_size = 4096*50\n",
    "# Learning rate passed to learn.fit_one_cycle.\n",
    "lr = 0.09"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Process data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "read train used 9.11 seconds.\n",
      "   row_id  candidate_order   item_id  price  row_id_count       user_id  \\\n",
      "0    1236               17     73376    399            25  XFE5BT9RNTQW   \n",
      "1    1236               18  10348476    201            25  XFE5BT9RNTQW   \n",
      "2    1236               19    407711    425            25  XFE5BT9RNTQW   \n",
      "3    1236               20   9882224    153            25  XFE5BT9RNTQW   \n",
      "4    1236               21  10455202    813            25  XFE5BT9RNTQW   \n",
      "\n",
      "      session_id   timestamp  step    action_type  reference platform  \\\n",
      "0  62f66f7671352  1541073311     6  clickout item    5993048       US   \n",
      "1  62f66f7671352  1541073311     6  clickout item    5993048       US   \n",
      "2  62f66f7671352  1541073311     6  clickout item    5993048       US   \n",
      "3  62f66f7671352  1541073311     6  clickout item    5993048       US   \n",
      "4  62f66f7671352  1541073311     6  clickout item    5993048       US   \n",
      "\n",
      "             city   device    current_filters  clickout_missing  target  \n",
      "0  Nashville, USA  desktop  Focus on Distance                 0       0  \n",
      "1  Nashville, USA  desktop  Focus on Distance                 0       0  \n",
      "2  Nashville, USA  desktop  Focus on Distance                 0       0  \n",
      "3  Nashville, USA  desktop  Focus on Distance                 0       0  \n",
      "4  Nashville, USA  desktop  Focus on Distance                 0       0  \n",
      "get variables names used 0.00 seconds.\n",
      "processing train used 5.18 seconds.\n",
      "read test used 5.64 seconds.\n",
      "processing test used 5.36 seconds.\n",
      "read valid used 6.58 seconds.\n",
      "processing valid used 7.10 seconds.\n",
      "The whole processing used 29.16 seconds.\n"
     ]
    }
   ],
   "source": [
    "# %%snakeviz \n",
    "# uncomment the line above to generate the snakeviz profile of preprocessing \n",
    "\n",
    "data_path = '../cache/'\n",
    "TEST = 'test'\n",
    "VALID = 'valid'\n",
    "TRAIN = 'train'\n",
    "\n",
    "start0 = time()  # start of the whole preprocessing stage\n",
    "data = {}        # split name -> (x, y) as returned by proc.preproc_dataframe\n",
    "\n",
    "############################\n",
    "#                          #\n",
    "# Fit processing train set #\n",
    "#                          #\n",
    "############################\n",
    "start = time()\n",
    "path = os.path.join(data_path,TRAIN+'.parquet' )\n",
    "ds = cudf.read_parquet(path)\n",
    "print(f\"read {TRAIN} used {time()-start:.2f} seconds.\")\n",
    "\n",
    "\n",
    "# get variable names: fixed base columns plus columns selected by name pattern\n",
    "start = time()\n",
    "cat_names = ['user_id','item_id','platform','city','device','current_filters'] + [i for i in ds.columns if i.startswith('is_')]\n",
    "cont_names = ['price','candidate_order'] + [i for i in ds.columns if i.startswith('count') or 'rank' in i or i.startswith('delta_')]\n",
    "print(f\"get variables names used {time()-start:.2f} seconds.\")\n",
    "\n",
    "# init the processing class \n",
    "proc = PreprocessDF(cat_names=cat_names, cont_names=cont_names, label_name='target', to_cpu=to_cpu)\n",
    "\n",
    "# Fit training \n",
    "start = time()\n",
    "x, y = proc.preproc_dataframe(ds, mode=TRAIN)\n",
    "print(f\"processing {TRAIN} used {time()-start:.2f} seconds.\")\n",
    "del ds\n",
    "data[TRAIN] = (x, y)\n",
    "\n",
    "############################\n",
    "#                          #\n",
    "# Transform test and valid #\n",
    "#                          #\n",
    "############################  \n",
    "ds_name = [TEST, VALID]\n",
    "for name in ds_name:\n",
    "    # Restart the timer before reading; otherwise the reported read time would\n",
    "    # also include the processing time of the previous split.\n",
    "    start = time()\n",
    "    path = os.path.join(data_path,name+'.parquet' )\n",
    "    ds = cudf.read_parquet(path)\n",
    "    print(f\"read {name} used {time()-start:.2f} seconds.\")\n",
    "\n",
    "    start = time()\n",
    "    x, y = proc.preproc_dataframe(ds, mode=name)\n",
    "    print(f\"processing {name} used {time()-start:.2f} seconds.\")\n",
    "    data[name] = (x, y)\n",
    "    del ds\n",
    "\n",
    "print(f\"The whole processing used {time()-start0:.2f} seconds.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Fastai training "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _as_batch_inputs(split):\n",
    "    \"Return `[x[0], x[1], y.long()]` for the `(x, y)` pair stored under `data[split]`.\"\n",
    "    features, labels = data[split]\n",
    "    return [features[0], features[1], labels.long()]\n",
    "\n",
    "train = _as_batch_inputs('train')\n",
    "validation = _as_batch_inputs('valid')\n",
    "test = _as_batch_inputs('test')\n",
    "\n",
    "# Only the train and validation splits are handed to the databunch here.\n",
    "databunch = BatchDataBunch.create(train, validation, device='cuda', bs=batch_size)   "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>0.153928</td>\n",
       "      <td>0.145674</td>\n",
       "      <td>00:14</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " \n",
      "*** Profile stats marshalled to file '/tmp/tmpwq0g5bw1'. \n",
      "Embedding SnakeViz in this document...\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "<iframe id='snakeviz-ce8375cc-d69a-11e9-ae7d-0242ac110002' frameborder=0 seamless width='100%' height='1000'></iframe>\n",
       "<script>document.getElementById(\"snakeviz-ce8375cc-d69a-11e9-ae7d-0242ac110002\").setAttribute(\"src\", \"http://\" + document.location.hostname + \":8080/snakeviz/%2Ftmp%2Ftmpwq0g5bw1\")</script>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "%%snakeviz\n",
    "# One size pair per embedded categorical input, hard-coded.\n",
    "# NOTE(review): presumably (number of categories, embedding width) matching the encodings produced by `proc` — confirm against PreprocessDF.\n",
    "emb_sz = [(938604, 16), (903867, 16), (56, 4), (32763, 8), (4, 1), (27842, 8)]  \n",
    "\n",
    "# Two outputs, later scored with CrossEntropyLoss; the model is moved to the GPU explicitly.\n",
    "model = TabularModel(emb_szs = emb_sz, n_cont=len(cont_names), out_sz=2, layers=[64, 32]).cuda()\n",
    "\n",
    "learn =  Learner(databunch, model, metrics=None)\n",
    "\n",
    "# Set the loss after construction rather than through the Learner arguments.\n",
    "learn.loss_func = torch.nn.CrossEntropyLoss()\n",
    "\n",
    "# Wall-clock time of one fit_one_cycle epoch; note this cell runs under the %%snakeviz profiler.\n",
    "start = time()\n",
    "learn.fit_one_cycle(1, lr)\n",
    "t_final = time() - start "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score the validation split, then compute AUC and mean reciprocal rank.\n",
    "path = os.path.join(data_path, VALID + '.parquet')\n",
    "ds = pd.read_parquet(path)\n",
    "\n",
    "# `databunch` is passed positionally as the first argument of get_preds.\n",
    "yp, y_valid = learn.get_preds(databunch)\n",
    "prob_col1 = yp.numpy()[:, 1]  # column 1 of the returned predictions, reused below\n",
    "\n",
    "# Attach the score to the raw validation rows and order them for ranking.\n",
    "cv = ds[['row_id', 'reference', 'item_id', 'target']].copy()\n",
    "cv['prob'] = prob_col1\n",
    "cv = cv.sort_values(by=['row_id', 'prob'], ascending=False)\n",
    "\n",
    "auc = roc_auc_score(y_valid.numpy().ravel(), prob_col1)\n",
    "mean_reciprocal_rank = get_mean_reciprocal_rank(cv)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "the mrr of the best mdodel is: 0.46830685515870263 \n",
      "the auc of the best mdodel is: 0.8020382039050711 \n",
      "the best mdodel's training time is 14.842254400253296 \n"
     ]
    }
   ],
   "source": [
    "# Report the validation metrics and the training time measured in the training cell.\n",
    "print(\"the mrr of the best model is: %s \" %mean_reciprocal_rank)\n",
    "\n",
    "print(\"the auc of the best model is: %s \" %auc)\n",
    "\n",
    "print(\"the best model's training time is %s \" %t_final)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
